summaryrefslogtreecommitdiffstats
path: root/sqlglot/optimizer
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/optimizer')
-rw-r--r--sqlglot/optimizer/canonicalize.py10
-rw-r--r--sqlglot/optimizer/eliminate_ctes.py39
-rw-r--r--sqlglot/optimizer/eliminate_subqueries.py19
-rw-r--r--sqlglot/optimizer/expand_laterals.py34
-rw-r--r--sqlglot/optimizer/expand_multi_table_selects.py24
-rw-r--r--sqlglot/optimizer/isolate_table_selects.py2
-rw-r--r--sqlglot/optimizer/lower_identities.py88
-rw-r--r--sqlglot/optimizer/merge_subqueries.py17
-rw-r--r--sqlglot/optimizer/normalize.py35
-rw-r--r--sqlglot/optimizer/normalize_identifiers.py36
-rw-r--r--sqlglot/optimizer/optimize_joins.py7
-rw-r--r--sqlglot/optimizer/optimizer.py37
-rw-r--r--sqlglot/optimizer/pushdown_predicates.py42
-rw-r--r--sqlglot/optimizer/pushdown_projections.py9
-rw-r--r--sqlglot/optimizer/qualify.py80
-rw-r--r--sqlglot/optimizer/qualify_columns.py221
-rw-r--r--sqlglot/optimizer/qualify_tables.py45
-rw-r--r--sqlglot/optimizer/scope.py43
-rw-r--r--sqlglot/optimizer/simplify.py32
-rw-r--r--sqlglot/optimizer/unnest_subqueries.py23
20 files changed, 451 insertions, 392 deletions
diff --git a/sqlglot/optimizer/canonicalize.py b/sqlglot/optimizer/canonicalize.py
index ef929ac..da2fce8 100644
--- a/sqlglot/optimizer/canonicalize.py
+++ b/sqlglot/optimizer/canonicalize.py
@@ -3,10 +3,9 @@ from __future__ import annotations
import itertools
from sqlglot import exp
-from sqlglot.helper import should_identify
-def canonicalize(expression: exp.Expression, identify: str = "safe") -> exp.Expression:
+def canonicalize(expression: exp.Expression) -> exp.Expression:
"""Converts a sql expression into a standard form.
This method relies on annotate_types because many of the
@@ -14,19 +13,14 @@ def canonicalize(expression: exp.Expression, identify: str = "safe") -> exp.Expr
Args:
expression: The expression to canonicalize.
- identify: Whether or not to force identify identifier.
"""
- exp.replace_children(expression, canonicalize, identify=identify)
+ exp.replace_children(expression, canonicalize)
expression = add_text_to_concat(expression)
expression = coerce_type(expression)
expression = remove_redundant_casts(expression)
expression = ensure_bool_predicates(expression)
- if isinstance(expression, exp.Identifier):
- if should_identify(expression.this, identify):
- expression.set("quoted", True)
-
return expression
diff --git a/sqlglot/optimizer/eliminate_ctes.py b/sqlglot/optimizer/eliminate_ctes.py
index 7b862c6..6f1865c 100644
--- a/sqlglot/optimizer/eliminate_ctes.py
+++ b/sqlglot/optimizer/eliminate_ctes.py
@@ -19,24 +19,25 @@ def eliminate_ctes(expression):
"""
root = build_scope(expression)
- ref_count = root.ref_count()
-
- # Traverse the scope tree in reverse so we can remove chains of unused CTEs
- for scope in reversed(list(root.traverse())):
- if scope.is_cte:
- count = ref_count[id(scope)]
- if count <= 0:
- cte_node = scope.expression.parent
- with_node = cte_node.parent
- cte_node.pop()
-
- # Pop the entire WITH clause if this is the last CTE
- if len(with_node.expressions) <= 0:
- with_node.pop()
-
- # Decrement the ref count for all sources this CTE selects from
- for _, source in scope.selected_sources.values():
- if isinstance(source, Scope):
- ref_count[id(source)] -= 1
+ if root:
+ ref_count = root.ref_count()
+
+ # Traverse the scope tree in reverse so we can remove chains of unused CTEs
+ for scope in reversed(list(root.traverse())):
+ if scope.is_cte:
+ count = ref_count[id(scope)]
+ if count <= 0:
+ cte_node = scope.expression.parent
+ with_node = cte_node.parent
+ cte_node.pop()
+
+ # Pop the entire WITH clause if this is the last CTE
+ if len(with_node.expressions) <= 0:
+ with_node.pop()
+
+ # Decrement the ref count for all sources this CTE selects from
+ for _, source in scope.selected_sources.values():
+ if isinstance(source, Scope):
+ ref_count[id(source)] -= 1
return expression
diff --git a/sqlglot/optimizer/eliminate_subqueries.py b/sqlglot/optimizer/eliminate_subqueries.py
index a39fe96..84f50e9 100644
--- a/sqlglot/optimizer/eliminate_subqueries.py
+++ b/sqlglot/optimizer/eliminate_subqueries.py
@@ -16,9 +16,9 @@ def eliminate_subqueries(expression):
'WITH y AS (SELECT * FROM x) SELECT a FROM y AS y'
This also deduplicates common subqueries:
- >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT * FROM x) AS y JOIN (SELECT * FROM x) AS z")
+ >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT * FROM x) AS y CROSS JOIN (SELECT * FROM x) AS z")
>>> eliminate_subqueries(expression).sql()
- 'WITH y AS (SELECT * FROM x) SELECT a FROM y AS y JOIN y AS z'
+ 'WITH y AS (SELECT * FROM x) SELECT a FROM y AS y CROSS JOIN y AS z'
Args:
expression (sqlglot.Expression): expression
@@ -32,6 +32,9 @@ def eliminate_subqueries(expression):
root = build_scope(expression)
+ if not root:
+ return expression
+
# Map of alias->Scope|Table
# These are all aliases that are already used in the expression.
# We don't want to create new CTEs that conflict with these names.
@@ -112,7 +115,7 @@ def _eliminate_union(scope, existing_ctes, taken):
# Try to maintain the selections
expressions = scope.selects
selects = [
- exp.alias_(exp.column(e.alias_or_name, table=alias), alias=e.alias_or_name)
+ exp.alias_(exp.column(e.alias_or_name, table=alias), alias=e.alias_or_name, copy=False)
for e in expressions
if e.alias_or_name
]
@@ -120,7 +123,9 @@ def _eliminate_union(scope, existing_ctes, taken):
if len(selects) != len(expressions):
selects = ["*"]
- scope.expression.replace(exp.select(*selects).from_(exp.alias_(exp.table_(alias), alias=alias)))
+ scope.expression.replace(
+ exp.select(*selects).from_(exp.alias_(exp.table_(alias), alias=alias, copy=False))
+ )
if not duplicate_cte_alias:
existing_ctes[scope.expression] = alias
@@ -131,6 +136,10 @@ def _eliminate_union(scope, existing_ctes, taken):
def _eliminate_derived_table(scope, existing_ctes, taken):
+ # This ensures we don't drop the "pivot" arg from a pivoted subquery
+ if scope.parent.pivots:
+ return None
+
parent = scope.expression.parent
name, cte = _new_cte(scope, existing_ctes, taken)
@@ -153,7 +162,7 @@ def _eliminate_cte(scope, existing_ctes, taken):
for child_scope in scope.parent.traverse():
for table, source in child_scope.selected_sources.values():
if source is scope:
- new_table = exp.alias_(exp.table_(name), alias=table.alias_or_name)
+ new_table = exp.alias_(exp.table_(name), alias=table.alias_or_name, copy=False)
table.replace(new_table)
return cte
diff --git a/sqlglot/optimizer/expand_laterals.py b/sqlglot/optimizer/expand_laterals.py
deleted file mode 100644
index 5b2f706..0000000
--- a/sqlglot/optimizer/expand_laterals.py
+++ /dev/null
@@ -1,34 +0,0 @@
-from __future__ import annotations
-
-import typing as t
-
-from sqlglot import exp
-
-
-def expand_laterals(expression: exp.Expression) -> exp.Expression:
- """
- Expand lateral column alias references.
-
- This assumes `qualify_columns` as already run.
-
- Example:
- >>> import sqlglot
- >>> sql = "SELECT x.a + 1 AS b, b + 1 AS c FROM x"
- >>> expression = sqlglot.parse_one(sql)
- >>> expand_laterals(expression).sql()
- 'SELECT x.a + 1 AS b, x.a + 1 + 1 AS c FROM x'
-
- Args:
- expression: expression to optimize
- Returns:
- optimized expression
- """
- for select in expression.find_all(exp.Select):
- alias_to_expression: t.Dict[str, exp.Expression] = {}
- for projection in select.expressions:
- for column in projection.find_all(exp.Column):
- if not column.table and column.name in alias_to_expression:
- column.replace(alias_to_expression[column.name].copy())
- if isinstance(projection, exp.Alias):
- alias_to_expression[projection.alias] = projection.this
- return expression
diff --git a/sqlglot/optimizer/expand_multi_table_selects.py b/sqlglot/optimizer/expand_multi_table_selects.py
deleted file mode 100644
index 86f0c2d..0000000
--- a/sqlglot/optimizer/expand_multi_table_selects.py
+++ /dev/null
@@ -1,24 +0,0 @@
-from sqlglot import exp
-
-
-def expand_multi_table_selects(expression):
- """
- Replace multiple FROM expressions with JOINs.
-
- Example:
- >>> from sqlglot import parse_one
- >>> expand_multi_table_selects(parse_one("SELECT * FROM x, y")).sql()
- 'SELECT * FROM x CROSS JOIN y'
- """
- for from_ in expression.find_all(exp.From):
- parent = from_.parent
-
- for query in from_.expressions[1:]:
- parent.join(
- query,
- join_type="CROSS",
- copy=False,
- )
- from_.expressions.remove(query)
-
- return expression
diff --git a/sqlglot/optimizer/isolate_table_selects.py b/sqlglot/optimizer/isolate_table_selects.py
index 5d78353..5dfa4aa 100644
--- a/sqlglot/optimizer/isolate_table_selects.py
+++ b/sqlglot/optimizer/isolate_table_selects.py
@@ -21,7 +21,7 @@ def isolate_table_selects(expression, schema=None):
source.replace(
exp.select("*")
.from_(
- alias(source.copy(), source.name or source.alias, table=True),
+ alias(source, source.name or source.alias, table=True),
copy=False,
)
.subquery(source.alias, copy=False)
diff --git a/sqlglot/optimizer/lower_identities.py b/sqlglot/optimizer/lower_identities.py
deleted file mode 100644
index fae1726..0000000
--- a/sqlglot/optimizer/lower_identities.py
+++ /dev/null
@@ -1,88 +0,0 @@
-from sqlglot import exp
-
-
-def lower_identities(expression):
- """
- Convert all unquoted identifiers to lower case.
-
- Assuming the schema is all lower case, this essentially makes identifiers case-insensitive.
-
- Example:
- >>> import sqlglot
- >>> expression = sqlglot.parse_one('SELECT Bar.A AS A FROM "Foo".Bar')
- >>> lower_identities(expression).sql()
- 'SELECT bar.a AS A FROM "Foo".bar'
-
- Args:
- expression (sqlglot.Expression): expression to quote
- Returns:
- sqlglot.Expression: quoted expression
- """
- # We need to leave the output aliases unchanged, so the selects need special handling
- _lower_selects(expression)
-
- # These clauses can reference output aliases and also need special handling
- _lower_order(expression)
- _lower_having(expression)
-
- # We've already handled these args, so don't traverse into them
- traversed = {"expressions", "order", "having"}
-
- if isinstance(expression, exp.Subquery):
- # Root subquery, e.g. (SELECT A AS A FROM X) LIMIT 1
- lower_identities(expression.this)
- traversed |= {"this"}
-
- if isinstance(expression, exp.Union):
- # Union, e.g. SELECT A AS A FROM X UNION SELECT A AS A FROM X
- lower_identities(expression.left)
- lower_identities(expression.right)
- traversed |= {"this", "expression"}
-
- for k, v in expression.iter_expressions():
- if k in traversed:
- continue
- v.transform(_lower, copy=False)
-
- return expression
-
-
-def _lower_selects(expression):
- for e in expression.expressions:
- # Leave output aliases as-is
- e.unalias().transform(_lower, copy=False)
-
-
-def _lower_order(expression):
- order = expression.args.get("order")
-
- if not order:
- return
-
- output_aliases = {e.alias for e in expression.expressions if isinstance(e, exp.Alias)}
-
- for ordered in order.expressions:
- # Don't lower references to output aliases
- if not (
- isinstance(ordered.this, exp.Column)
- and not ordered.this.table
- and ordered.this.name in output_aliases
- ):
- ordered.transform(_lower, copy=False)
-
-
-def _lower_having(expression):
- having = expression.args.get("having")
-
- if not having:
- return
-
- # Don't lower references to output aliases
- for agg in having.find_all(exp.AggFunc):
- agg.transform(_lower, copy=False)
-
-
-def _lower(node):
- if isinstance(node, exp.Identifier) and not node.quoted:
- node.set("this", node.this.lower())
- return node
diff --git a/sqlglot/optimizer/merge_subqueries.py b/sqlglot/optimizer/merge_subqueries.py
index c3467b2..f9c9664 100644
--- a/sqlglot/optimizer/merge_subqueries.py
+++ b/sqlglot/optimizer/merge_subqueries.py
@@ -13,15 +13,15 @@ def merge_subqueries(expression, leave_tables_isolated=False):
Example:
>>> import sqlglot
- >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT x.a FROM x) JOIN y")
+ >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT x.a FROM x) CROSS JOIN y")
>>> merge_subqueries(expression).sql()
- 'SELECT x.a FROM x JOIN y'
+ 'SELECT x.a FROM x CROSS JOIN y'
If `leave_tables_isolated` is True, this will not merge inner queries into outer
queries if it would result in multiple table selects in a single query:
- >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT x.a FROM x) JOIN y")
+ >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT x.a FROM x) CROSS JOIN y")
>>> merge_subqueries(expression, leave_tables_isolated=True).sql()
- 'SELECT a FROM (SELECT x.a FROM x) JOIN y'
+ 'SELECT a FROM (SELECT x.a FROM x) CROSS JOIN y'
Inspired by https://dev.mysql.com/doc/refman/8.0/en/derived-table-optimization.html
@@ -154,7 +154,7 @@ def _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join):
inner_from = inner_scope.expression.args.get("from")
if not inner_from:
return False
- inner_from_table = inner_from.expressions[0].alias_or_name
+ inner_from_table = inner_from.alias_or_name
inner_projections = {s.alias_or_name: s for s in inner_scope.selects}
return any(
col.table != inner_from_table
@@ -167,6 +167,7 @@ def _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join):
and isinstance(inner_select, exp.Select)
and not any(inner_select.args.get(arg) for arg in UNMERGABLE_ARGS)
and inner_select.args.get("from")
+ and not outer_scope.pivots
and not any(e.find(exp.AggFunc, exp.Select) for e in inner_select.expressions)
and not (leave_tables_isolated and len(outer_scope.selected_sources) > 1)
and not (
@@ -210,7 +211,7 @@ def _rename_inner_sources(outer_scope, inner_scope, alias):
elif isinstance(source, exp.Table) and source.alias:
source.set("alias", new_alias)
elif isinstance(source, exp.Table):
- source.replace(exp.alias_(source.copy(), new_alias))
+ source.replace(exp.alias_(source, new_alias))
for column in inner_scope.source_columns(conflict):
column.set("table", exp.to_identifier(new_name))
@@ -228,7 +229,7 @@ def _merge_from(outer_scope, inner_scope, node_to_replace, alias):
node_to_replace (exp.Subquery|exp.Table)
alias (str)
"""
- new_subquery = inner_scope.expression.args.get("from").expressions[0]
+ new_subquery = inner_scope.expression.args["from"].this
node_to_replace.replace(new_subquery)
for join_hint in outer_scope.join_hints:
tables = join_hint.find_all(exp.Table)
@@ -319,7 +320,7 @@ def _merge_where(outer_scope, inner_scope, from_or_join):
# Merge predicates from an outer join to the ON clause
# if it only has columns that are already joined
from_ = expression.args.get("from")
- sources = {table.alias_or_name for table in from_.expressions} if from_ else {}
+ sources = {from_.alias_or_name} if from_ else {}
for join in expression.args["joins"]:
source = join.alias_or_name
diff --git a/sqlglot/optimizer/normalize.py b/sqlglot/optimizer/normalize.py
index b013312..1db094e 100644
--- a/sqlglot/optimizer/normalize.py
+++ b/sqlglot/optimizer/normalize.py
@@ -1,12 +1,12 @@
from __future__ import annotations
import logging
-import typing as t
from sqlglot import exp
from sqlglot.errors import OptimizeError
+from sqlglot.generator import cached_generator
from sqlglot.helper import while_changing
-from sqlglot.optimizer.simplify import flatten, uniq_sort
+from sqlglot.optimizer.simplify import flatten, rewrite_between, uniq_sort
logger = logging.getLogger("sqlglot")
@@ -28,13 +28,16 @@ def normalize(expression: exp.Expression, dnf: bool = False, max_distance: int =
Returns:
sqlglot.Expression: normalized expression
"""
- cache: t.Dict[int, str] = {}
+ generate = cached_generator()
for node, *_ in tuple(expression.walk(prune=lambda e, *_: isinstance(e, exp.Connector))):
if isinstance(node, exp.Connector):
if normalized(node, dnf=dnf):
continue
+ root = node is expression
+ original = node.copy()
+ node.transform(rewrite_between, copy=False)
distance = normalization_distance(node, dnf=dnf)
if distance > max_distance:
@@ -43,11 +46,9 @@ def normalize(expression: exp.Expression, dnf: bool = False, max_distance: int =
)
return expression
- root = node is expression
- original = node.copy()
try:
node = node.replace(
- while_changing(node, lambda e: distributive_law(e, dnf, max_distance, cache))
+ while_changing(node, lambda e: distributive_law(e, dnf, max_distance, generate))
)
except OptimizeError as e:
logger.info(e)
@@ -111,7 +112,7 @@ def _predicate_lengths(expression, dnf):
return _predicate_lengths(left, dnf) + _predicate_lengths(right, dnf)
-def distributive_law(expression, dnf, max_distance, cache=None):
+def distributive_law(expression, dnf, max_distance, generate):
"""
x OR (y AND z) -> (x OR y) AND (x OR z)
(x AND y) OR (y AND z) -> (x OR y) AND (x OR z) AND (y OR y) AND (y OR z)
@@ -124,7 +125,7 @@ def distributive_law(expression, dnf, max_distance, cache=None):
if distance > max_distance:
raise OptimizeError(f"Normalization distance {distance} exceeds max {max_distance}")
- exp.replace_children(expression, lambda e: distributive_law(e, dnf, max_distance, cache))
+ exp.replace_children(expression, lambda e: distributive_law(e, dnf, max_distance, generate))
to_exp, from_exp = (exp.Or, exp.And) if dnf else (exp.And, exp.Or)
if isinstance(expression, from_exp):
@@ -135,30 +136,30 @@ def distributive_law(expression, dnf, max_distance, cache=None):
if isinstance(a, to_exp) and isinstance(b, to_exp):
if len(tuple(a.find_all(exp.Connector))) > len(tuple(b.find_all(exp.Connector))):
- return _distribute(a, b, from_func, to_func, cache)
- return _distribute(b, a, from_func, to_func, cache)
+ return _distribute(a, b, from_func, to_func, generate)
+ return _distribute(b, a, from_func, to_func, generate)
if isinstance(a, to_exp):
- return _distribute(b, a, from_func, to_func, cache)
+ return _distribute(b, a, from_func, to_func, generate)
if isinstance(b, to_exp):
- return _distribute(a, b, from_func, to_func, cache)
+ return _distribute(a, b, from_func, to_func, generate)
return expression
-def _distribute(a, b, from_func, to_func, cache):
+def _distribute(a, b, from_func, to_func, generate):
if isinstance(a, exp.Connector):
exp.replace_children(
a,
lambda c: to_func(
- uniq_sort(flatten(from_func(c, b.left)), cache),
- uniq_sort(flatten(from_func(c, b.right)), cache),
+ uniq_sort(flatten(from_func(c, b.left)), generate),
+ uniq_sort(flatten(from_func(c, b.right)), generate),
copy=False,
),
)
else:
a = to_func(
- uniq_sort(flatten(from_func(a, b.left)), cache),
- uniq_sort(flatten(from_func(a, b.right)), cache),
+ uniq_sort(flatten(from_func(a, b.left)), generate),
+ uniq_sort(flatten(from_func(a, b.right)), generate),
copy=False,
)
diff --git a/sqlglot/optimizer/normalize_identifiers.py b/sqlglot/optimizer/normalize_identifiers.py
new file mode 100644
index 0000000..1e5c104
--- /dev/null
+++ b/sqlglot/optimizer/normalize_identifiers.py
@@ -0,0 +1,36 @@
+from sqlglot import exp
+from sqlglot._typing import E
+from sqlglot.dialects.dialect import RESOLVES_IDENTIFIERS_AS_UPPERCASE, DialectType
+
+
+def normalize_identifiers(expression: E, dialect: DialectType = None) -> E:
+ """
+ Normalize all unquoted identifiers to either lower or upper case, depending on
+ the dialect. This essentially makes those identifiers case-insensitive.
+
+ Example:
+ >>> import sqlglot
+ >>> expression = sqlglot.parse_one('SELECT Bar.A AS A FROM "Foo".Bar')
+ >>> normalize_identifiers(expression).sql()
+ 'SELECT bar.a AS a FROM "Foo".bar'
+
+ Args:
+ expression: The expression to transform.
+ dialect: The dialect to use in order to decide how to normalize identifiers.
+
+ Returns:
+ The transformed expression.
+ """
+ return expression.transform(_normalize, dialect, copy=False)
+
+
+def _normalize(node: exp.Expression, dialect: DialectType = None) -> exp.Expression:
+ if isinstance(node, exp.Identifier) and not node.quoted:
+ node.set(
+ "this",
+ node.this.upper()
+ if dialect in RESOLVES_IDENTIFIERS_AS_UPPERCASE
+ else node.this.lower(),
+ )
+
+ return node
diff --git a/sqlglot/optimizer/optimize_joins.py b/sqlglot/optimizer/optimize_joins.py
index 8589657..43436cb 100644
--- a/sqlglot/optimizer/optimize_joins.py
+++ b/sqlglot/optimizer/optimize_joins.py
@@ -1,6 +1,8 @@
from sqlglot import exp
from sqlglot.helper import tsort
+JOIN_ATTRS = ("on", "side", "kind", "using", "natural")
+
def optimize_joins(expression):
"""
@@ -45,7 +47,7 @@ def reorder_joins(expression):
Reorder joins by topological sort order based on predicate references.
"""
for from_ in expression.find_all(exp.From):
- head = from_.expressions[0]
+ head = from_.this
parent = from_.parent
joins = {join.this.alias_or_name: join for join in parent.args.get("joins", [])}
dag = {head.alias_or_name: []}
@@ -65,6 +67,9 @@ def normalize(expression):
Remove INNER and OUTER from joins as they are optional.
"""
for join in expression.find_all(exp.Join):
+ if not any(join.args.get(k) for k in JOIN_ATTRS):
+ join.set("kind", "CROSS")
+
if join.kind != "CROSS":
join.set("kind", None)
return expression
diff --git a/sqlglot/optimizer/optimizer.py b/sqlglot/optimizer/optimizer.py
index c165ffe..dbe33a2 100644
--- a/sqlglot/optimizer/optimizer.py
+++ b/sqlglot/optimizer/optimizer.py
@@ -10,36 +10,29 @@ from sqlglot.optimizer.canonicalize import canonicalize
from sqlglot.optimizer.eliminate_ctes import eliminate_ctes
from sqlglot.optimizer.eliminate_joins import eliminate_joins
from sqlglot.optimizer.eliminate_subqueries import eliminate_subqueries
-from sqlglot.optimizer.expand_multi_table_selects import expand_multi_table_selects
-from sqlglot.optimizer.isolate_table_selects import isolate_table_selects
-from sqlglot.optimizer.lower_identities import lower_identities
from sqlglot.optimizer.merge_subqueries import merge_subqueries
from sqlglot.optimizer.normalize import normalize
from sqlglot.optimizer.optimize_joins import optimize_joins
from sqlglot.optimizer.pushdown_predicates import pushdown_predicates
from sqlglot.optimizer.pushdown_projections import pushdown_projections
-from sqlglot.optimizer.qualify_columns import qualify_columns, validate_qualify_columns
-from sqlglot.optimizer.qualify_tables import qualify_tables
+from sqlglot.optimizer.qualify import qualify
+from sqlglot.optimizer.qualify_columns import quote_identifiers
from sqlglot.optimizer.simplify import simplify
from sqlglot.optimizer.unnest_subqueries import unnest_subqueries
from sqlglot.schema import ensure_schema
RULES = (
- lower_identities,
- qualify_tables,
- isolate_table_selects,
- qualify_columns,
+ qualify,
pushdown_projections,
- validate_qualify_columns,
normalize,
unnest_subqueries,
- expand_multi_table_selects,
pushdown_predicates,
optimize_joins,
eliminate_subqueries,
merge_subqueries,
eliminate_joins,
eliminate_ctes,
+ quote_identifiers,
annotate_types,
canonicalize,
simplify,
@@ -54,7 +47,7 @@ def optimize(
dialect: DialectType = None,
rules: t.Sequence[t.Callable] = RULES,
**kwargs,
-):
+) -> exp.Expression:
"""
Rewrite a sqlglot AST into an optimized form.
@@ -72,14 +65,23 @@ def optimize(
dialect: The dialect to parse the sql string.
rules: sequence of optimizer rules to use.
Many of the rules require tables and columns to be qualified.
- Do not remove qualify_tables or qualify_columns from the sequence of rules unless you know
- what you're doing!
+ Do not remove `qualify` from the sequence of rules unless you know what you're doing!
**kwargs: If a rule has a keyword argument with a same name in **kwargs, it will be passed in.
+
Returns:
- sqlglot.Expression: optimized expression
+ The optimized expression.
"""
schema = ensure_schema(schema or sqlglot.schema, dialect=dialect)
- possible_kwargs = {"db": db, "catalog": catalog, "schema": schema, **kwargs}
+ possible_kwargs = {
+ "db": db,
+ "catalog": catalog,
+ "schema": schema,
+ "dialect": dialect,
+ "isolate_tables": True, # needed for other optimizations to perform well
+ "quote_identifiers": False, # this happens in canonicalize
+ **kwargs,
+ }
+
expression = exp.maybe_parse(expression, dialect=dialect, copy=True)
for rule in rules:
# Find any additional rule parameters, beyond `expression`
@@ -88,4 +90,5 @@ def optimize(
param: possible_kwargs[param] for param in rule_params if param in possible_kwargs
}
expression = rule(expression, **rule_kwargs)
- return expression
+
+ return t.cast(exp.Expression, expression)
diff --git a/sqlglot/optimizer/pushdown_predicates.py b/sqlglot/optimizer/pushdown_predicates.py
index ba5c8b5..96dda33 100644
--- a/sqlglot/optimizer/pushdown_predicates.py
+++ b/sqlglot/optimizer/pushdown_predicates.py
@@ -21,26 +21,28 @@ def pushdown_predicates(expression):
sqlglot.Expression: optimized expression
"""
root = build_scope(expression)
- scope_ref_count = root.ref_count()
-
- for scope in reversed(list(root.traverse())):
- select = scope.expression
- where = select.args.get("where")
- if where:
- selected_sources = scope.selected_sources
- # a right join can only push down to itself and not the source FROM table
- for k, (node, source) in selected_sources.items():
- parent = node.find_ancestor(exp.Join, exp.From)
- if isinstance(parent, exp.Join) and parent.side == "RIGHT":
- selected_sources = {k: (node, source)}
- break
- pushdown(where.this, selected_sources, scope_ref_count)
-
- # joins should only pushdown into itself, not to other joins
- # so we limit the selected sources to only itself
- for join in select.args.get("joins") or []:
- name = join.this.alias_or_name
- pushdown(join.args.get("on"), {name: scope.selected_sources[name]}, scope_ref_count)
+
+ if root:
+ scope_ref_count = root.ref_count()
+
+ for scope in reversed(list(root.traverse())):
+ select = scope.expression
+ where = select.args.get("where")
+ if where:
+ selected_sources = scope.selected_sources
+ # a right join can only push down to itself and not the source FROM table
+ for k, (node, source) in selected_sources.items():
+ parent = node.find_ancestor(exp.Join, exp.From)
+ if isinstance(parent, exp.Join) and parent.side == "RIGHT":
+ selected_sources = {k: (node, source)}
+ break
+ pushdown(where.this, selected_sources, scope_ref_count)
+
+ # joins should only pushdown into itself, not to other joins
+ # so we limit the selected sources to only itself
+ for join in select.args.get("joins") or []:
+ name = join.this.alias_or_name
+ pushdown(join.args.get("on"), {name: scope.selected_sources[name]}, scope_ref_count)
return expression
diff --git a/sqlglot/optimizer/pushdown_projections.py b/sqlglot/optimizer/pushdown_projections.py
index 2e51117..be3ddb2 100644
--- a/sqlglot/optimizer/pushdown_projections.py
+++ b/sqlglot/optimizer/pushdown_projections.py
@@ -39,8 +39,9 @@ def pushdown_projections(expression, schema=None, remove_unused_selections=True)
for scope in reversed(traverse_scope(expression)):
parent_selections = referenced_columns.get(scope, {SELECT_ALL})
- if scope.expression.args.get("distinct"):
- # We can't remove columns SELECT DISTINCT nor UNION DISTINCT
+ if scope.expression.args.get("distinct") or scope.parent and scope.parent.pivots:
+ # We can't remove columns SELECT DISTINCT nor UNION DISTINCT. The same holds if
+ # we select from a pivoted source in the parent scope.
parent_selections = {SELECT_ALL}
if isinstance(scope.expression, exp.Union):
@@ -105,7 +106,9 @@ def _remove_unused_selections(scope, parent_selections, schema):
for name in sorted(parent_selections):
if name not in names:
- new_selections.append(alias(exp.column(name, table=resolver.get_table(name)), name))
+ new_selections.append(
+ alias(exp.column(name, table=resolver.get_table(name)), name, copy=False)
+ )
# If there are no remaining selections, just select a single constant
if not new_selections:
diff --git a/sqlglot/optimizer/qualify.py b/sqlglot/optimizer/qualify.py
new file mode 100644
index 0000000..5fdbde8
--- /dev/null
+++ b/sqlglot/optimizer/qualify.py
@@ -0,0 +1,80 @@
+from __future__ import annotations
+
+import typing as t
+
+from sqlglot import exp
+from sqlglot.dialects.dialect import DialectType
+from sqlglot.optimizer.isolate_table_selects import isolate_table_selects
+from sqlglot.optimizer.normalize_identifiers import normalize_identifiers
+from sqlglot.optimizer.qualify_columns import (
+ qualify_columns as qualify_columns_func,
+ quote_identifiers as quote_identifiers_func,
+ validate_qualify_columns as validate_qualify_columns_func,
+)
+from sqlglot.optimizer.qualify_tables import qualify_tables
+from sqlglot.schema import Schema, ensure_schema
+
+
+def qualify(
+ expression: exp.Expression,
+ dialect: DialectType = None,
+ db: t.Optional[str] = None,
+ catalog: t.Optional[str] = None,
+ schema: t.Optional[dict | Schema] = None,
+ expand_alias_refs: bool = True,
+ infer_schema: t.Optional[bool] = None,
+ isolate_tables: bool = False,
+ qualify_columns: bool = True,
+ validate_qualify_columns: bool = True,
+ quote_identifiers: bool = True,
+ identify: bool = True,
+) -> exp.Expression:
+ """
+ Rewrite sqlglot AST to have normalized and qualified tables and columns.
+
+ This step is necessary for all further SQLGlot optimizations.
+
+ Example:
+ >>> import sqlglot
+ >>> schema = {"tbl": {"col": "INT"}}
+ >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
+ >>> qualify(expression, schema=schema).sql()
+ 'SELECT "tbl"."col" AS "col" FROM "tbl" AS "tbl"'
+
+ Args:
+ expression: Expression to qualify.
+ db: Default database name for tables.
+ catalog: Default catalog name for tables.
+ schema: Schema to infer column names and types.
+ expand_alias_refs: Whether or not to expand references to aliases.
+ infer_schema: Whether or not to infer the schema if missing.
+ isolate_tables: Whether or not to isolate table selects.
+ qualify_columns: Whether or not to qualify columns.
+ validate_qualify_columns: Whether or not to validate columns.
+ quote_identifiers: Whether or not to run the quote_identifiers step.
+ This step is necessary to ensure correctness for case sensitive queries.
+ But this flag is provided in case this step is performed at a later time.
+ identify: If True, quote all identifiers, else only necessary ones.
+
+ Returns:
+ The qualified expression.
+ """
+ schema = ensure_schema(schema, dialect=dialect)
+ expression = normalize_identifiers(expression, dialect=dialect)
+ expression = qualify_tables(expression, db=db, catalog=catalog, schema=schema)
+
+ if isolate_tables:
+ expression = isolate_table_selects(expression, schema=schema)
+
+ if qualify_columns:
+ expression = qualify_columns_func(
+ expression, schema, expand_alias_refs=expand_alias_refs, infer_schema=infer_schema
+ )
+
+ if quote_identifiers:
+ expression = quote_identifiers_func(expression, dialect=dialect, identify=identify)
+
+ if validate_qualify_columns:
+ validate_qualify_columns_func(expression)
+
+ return expression
diff --git a/sqlglot/optimizer/qualify_columns.py b/sqlglot/optimizer/qualify_columns.py
index 6ac39f0..4a31171 100644
--- a/sqlglot/optimizer/qualify_columns.py
+++ b/sqlglot/optimizer/qualify_columns.py
@@ -1,14 +1,23 @@
+from __future__ import annotations
+
import itertools
import typing as t
from sqlglot import alias, exp
+from sqlglot._typing import E
+from sqlglot.dialects.dialect import DialectType
from sqlglot.errors import OptimizeError
-from sqlglot.optimizer.expand_laterals import expand_laterals as _expand_laterals
-from sqlglot.optimizer.scope import Scope, traverse_scope
-from sqlglot.schema import ensure_schema
+from sqlglot.helper import case_sensitive, seq_get
+from sqlglot.optimizer.scope import Scope, traverse_scope, walk_in_scope
+from sqlglot.schema import Schema, ensure_schema
-def qualify_columns(expression, schema, expand_laterals=True):
+def qualify_columns(
+ expression: exp.Expression,
+ schema: dict | Schema,
+ expand_alias_refs: bool = True,
+ infer_schema: t.Optional[bool] = None,
+) -> exp.Expression:
"""
Rewrite sqlglot AST to have fully qualified columns.
@@ -20,32 +29,36 @@ def qualify_columns(expression, schema, expand_laterals=True):
'SELECT tbl.col AS col FROM tbl'
Args:
- expression (sqlglot.Expression): expression to qualify
- schema (dict|sqlglot.optimizer.Schema): Database schema
+ expression: expression to qualify
+ schema: Database schema
+ expand_alias_refs: whether or not to expand references to aliases
+ infer_schema: whether or not to infer the schema if missing
Returns:
sqlglot.Expression: qualified expression
"""
schema = ensure_schema(schema)
-
- if not schema.mapping and expand_laterals:
- expression = _expand_laterals(expression)
+ infer_schema = schema.empty if infer_schema is None else infer_schema
for scope in traverse_scope(expression):
- resolver = Resolver(scope, schema)
+ resolver = Resolver(scope, schema, infer_schema=infer_schema)
_pop_table_column_aliases(scope.ctes)
_pop_table_column_aliases(scope.derived_tables)
using_column_tables = _expand_using(scope, resolver)
+
+ if schema.empty and expand_alias_refs:
+ _expand_alias_refs(scope, resolver)
+
_qualify_columns(scope, resolver)
+
+ if not schema.empty and expand_alias_refs:
+ _expand_alias_refs(scope, resolver)
+
if not isinstance(scope.expression, exp.UDTF):
_expand_stars(scope, resolver, using_column_tables)
_qualify_outputs(scope)
- _expand_alias_refs(scope, resolver)
_expand_group_by(scope, resolver)
_expand_order_by(scope)
- if schema.mapping and expand_laterals:
- expression = _expand_laterals(expression)
-
return expression
@@ -55,9 +68,11 @@ def validate_qualify_columns(expression):
for scope in traverse_scope(expression):
if isinstance(scope.expression, exp.Select):
unqualified_columns.extend(scope.unqualified_columns)
- if scope.external_columns and not scope.is_correlated_subquery:
+ if scope.external_columns and not scope.is_correlated_subquery and not scope.pivots:
column = scope.external_columns[0]
- raise OptimizeError(f"Unknown table: '{column.table}' for column '{column}'")
+ raise OptimizeError(
+ f"""Column '{column}' could not be resolved{f" for table: '{column.table}'" if column.table else ''}"""
+ )
if unqualified_columns:
raise OptimizeError(f"Ambiguous columns: {unqualified_columns}")
@@ -142,52 +157,48 @@ def _expand_using(scope, resolver):
# Ensure selects keep their output name
if isinstance(column.parent, exp.Select):
- replacement = exp.alias_(replacement, alias=column.name)
+ replacement = alias(replacement, alias=column.name, copy=False)
scope.replace(column, replacement)
return column_tables
-def _expand_alias_refs(scope, resolver):
- selects = {}
-
- # Replace references to select aliases
- def transform(node, source_first=True):
- if isinstance(node, exp.Column) and not node.table:
- table = resolver.get_table(node.name)
+def _expand_alias_refs(scope: Scope, resolver: Resolver) -> None:
+ expression = scope.expression
- # Source columns get priority over select aliases
- if source_first and table:
- node.set("table", table)
- return node
+ if not isinstance(expression, exp.Select):
+ return
- if not selects:
- for s in scope.selects:
- selects[s.alias_or_name] = s
- select = selects.get(node.name)
+ alias_to_expression: t.Dict[str, exp.Expression] = {}
- if select:
- scope.clear_cache()
- if isinstance(select, exp.Alias):
- select = select.this
- return select.copy()
+ def replace_columns(
+ node: t.Optional[exp.Expression], expand: bool = True, resolve_agg: bool = False
+ ):
+ if not node:
+ return
- node.set("table", table)
- elif isinstance(node, exp.Expression) and not isinstance(node, exp.Subqueryable):
- exp.replace_children(node, transform, source_first)
+ for column, *_ in walk_in_scope(node):
+ if not isinstance(column, exp.Column):
+ continue
+ table = resolver.get_table(column.name) if resolve_agg and not column.table else None
+ if table and column.find_ancestor(exp.AggFunc):
+ column.set("table", table)
+ elif expand and not column.table and column.name in alias_to_expression:
+ column.replace(alias_to_expression[column.name].copy())
- return node
+ for projection in scope.selects:
+ replace_columns(projection)
- for select in scope.expression.selects:
- transform(select)
+ if isinstance(projection, exp.Alias):
+ alias_to_expression[projection.alias] = projection.this
- for modifier, source_first in (
- ("where", True),
- ("group", True),
- ("having", False),
- ):
- transform(scope.expression.args.get(modifier), source_first=source_first)
+ replace_columns(expression.args.get("where"))
+ replace_columns(expression.args.get("group"))
+ replace_columns(expression.args.get("having"), resolve_agg=True)
+ replace_columns(expression.args.get("qualify"), resolve_agg=True)
+ replace_columns(expression.args.get("order"), expand=False, resolve_agg=True)
+ scope.clear_cache()
def _expand_group_by(scope, resolver):
@@ -242,6 +253,12 @@ def _qualify_columns(scope, resolver):
raise OptimizeError(f"Unknown column: {column_name}")
if not column_table:
+ if scope.pivots and not column.find_ancestor(exp.Pivot):
+ # If the column is under the Pivot expression, we need to qualify it
+ # using the name of the pivoted source instead of the pivot's alias
+ column.set("table", exp.to_identifier(scope.pivots[0].alias))
+ continue
+
column_table = resolver.get_table(column_name)
# column_table can be a '' because bigquery unnest has no table alias
@@ -265,38 +282,12 @@ def _qualify_columns(scope, resolver):
if column_table:
column.replace(exp.Dot.build([exp.column(root, table=column_table), *parts]))
- columns_missing_from_scope = []
-
- # Determine whether each reference in the order by clause is to a column or an alias.
- order = scope.expression.args.get("order")
-
- if order:
- for ordered in order.expressions:
- for column in ordered.find_all(exp.Column):
- if (
- not column.table
- and column.parent is not ordered
- and column.name in resolver.all_columns
- ):
- columns_missing_from_scope.append(column)
-
- # Determine whether each reference in the having clause is to a column or an alias.
- having = scope.expression.args.get("having")
-
- if having:
- for column in having.find_all(exp.Column):
- if (
- not column.table
- and column.find_ancestor(exp.AggFunc)
- and column.name in resolver.all_columns
- ):
- columns_missing_from_scope.append(column)
-
- for column in columns_missing_from_scope:
- column_table = resolver.get_table(column.name)
-
- if column_table:
- column.set("table", column_table)
+ for pivot in scope.pivots:
+ for column in pivot.find_all(exp.Column):
+ if not column.table and column.name in resolver.all_columns:
+ column_table = resolver.get_table(column.name)
+ if column_table:
+ column.set("table", column_table)
def _expand_stars(scope, resolver, using_column_tables):
@@ -307,6 +298,19 @@ def _expand_stars(scope, resolver, using_column_tables):
replace_columns = {}
coalesced_columns = set()
+ # TODO: handle optimization of multiple PIVOTs (and possibly UNPIVOTs) in the future
+ pivot_columns = None
+ pivot_output_columns = None
+ pivot = seq_get(scope.pivots, 0)
+
+ has_pivoted_source = pivot and not pivot.args.get("unpivot")
+ if has_pivoted_source:
+ pivot_columns = set(col.output_name for col in pivot.find_all(exp.Column))
+
+ pivot_output_columns = [col.output_name for col in pivot.args.get("columns", [])]
+ if not pivot_output_columns:
+ pivot_output_columns = [col.alias_or_name for col in pivot.expressions]
+
for expression in scope.selects:
if isinstance(expression, exp.Star):
tables = list(scope.selected_sources)
@@ -323,9 +327,18 @@ def _expand_stars(scope, resolver, using_column_tables):
for table in tables:
if table not in scope.sources:
raise OptimizeError(f"Unknown table: {table}")
+
columns = resolver.get_source_columns(table, only_visible=True)
if columns and "*" not in columns:
+ if has_pivoted_source:
+ implicit_columns = [col for col in columns if col not in pivot_columns]
+ new_selections.extend(
+ exp.alias_(exp.column(name, table=pivot.alias), name, copy=False)
+ for name in implicit_columns + pivot_output_columns
+ )
+ continue
+
table_id = id(table)
for name in columns:
if name in using_column_tables and table in using_column_tables[name]:
@@ -337,16 +350,21 @@ def _expand_stars(scope, resolver, using_column_tables):
coalesce = [exp.column(name, table=table) for table in tables]
new_selections.append(
- exp.alias_(
- exp.Coalesce(this=coalesce[0], expressions=coalesce[1:]), alias=name
+ alias(
+ exp.Coalesce(this=coalesce[0], expressions=coalesce[1:]),
+ alias=name,
+ copy=False,
)
)
elif name not in except_columns.get(table_id, set()):
alias_ = replace_columns.get(table_id, {}).get(name, name)
- column = exp.column(name, table)
- new_selections.append(alias(column, alias_) if alias_ != name else column)
+ column = exp.column(name, table=table)
+ new_selections.append(
+ alias(column, alias_, copy=False) if alias_ != name else column
+ )
else:
return
+
scope.expression.set("expressions", new_selections)
@@ -388,9 +406,6 @@ def _qualify_outputs(scope):
selection = alias(
selection,
alias=selection.output_name or f"_col_{i}",
- quoted=True
- if isinstance(selection, exp.Column) and selection.this.quoted
- else None,
)
if aliased_column:
selection.set("alias", exp.to_identifier(aliased_column))
@@ -400,6 +415,23 @@ def _qualify_outputs(scope):
scope.expression.set("expressions", new_selections)
+def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool = True) -> E:
+ """Makes sure all identifiers that need to be quoted are quoted."""
+
+ def _quote(expression: E) -> E:
+ if isinstance(expression, exp.Identifier):
+ name = expression.this
+ expression.set(
+ "quoted",
+ identify
+ or case_sensitive(name, dialect=dialect)
+ or not exp.SAFE_IDENTIFIER_RE.match(name),
+ )
+ return expression
+
+ return expression.transform(_quote, copy=False)
+
+
class Resolver:
"""
Helper for resolving columns.
@@ -407,12 +439,13 @@ class Resolver:
This is a class so we can lazily load some things and easily share them across functions.
"""
- def __init__(self, scope, schema):
+ def __init__(self, scope, schema, infer_schema: bool = True):
self.scope = scope
self.schema = schema
self._source_columns = None
- self._unambiguous_columns = None
+ self._unambiguous_columns: t.Optional[t.Dict[str, str]] = None
self._all_columns = None
+ self._infer_schema = infer_schema
def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
"""
@@ -430,7 +463,7 @@ class Resolver:
table_name = self._unambiguous_columns.get(column_name)
- if not table_name:
+ if not table_name and self._infer_schema:
sources_without_schema = tuple(
source
for source, columns in self._get_all_source_columns().items()
@@ -450,11 +483,9 @@ class Resolver:
node_alias = node.args.get("alias")
if node_alias:
- return node_alias.this
+ return exp.to_identifier(node_alias.this)
- return exp.to_identifier(
- table_name, quoted=node.this.quoted if isinstance(node, exp.Table) else None
- )
+ return exp.to_identifier(table_name)
@property
def all_columns(self):
diff --git a/sqlglot/optimizer/qualify_tables.py b/sqlglot/optimizer/qualify_tables.py
index 1b451a6..fcc5f26 100644
--- a/sqlglot/optimizer/qualify_tables.py
+++ b/sqlglot/optimizer/qualify_tables.py
@@ -1,11 +1,19 @@
import itertools
+import typing as t
from sqlglot import alias, exp
-from sqlglot.helper import csv_reader
+from sqlglot._typing import E
+from sqlglot.helper import csv_reader, name_sequence
from sqlglot.optimizer.scope import Scope, traverse_scope
+from sqlglot.schema import Schema
-def qualify_tables(expression, db=None, catalog=None, schema=None):
+def qualify_tables(
+ expression: E,
+ db: t.Optional[str] = None,
+ catalog: t.Optional[str] = None,
+ schema: t.Optional[Schema] = None,
+) -> E:
"""
Rewrite sqlglot AST to have fully qualified tables. Additionally, this
replaces "join constructs" (*) by equivalent SELECT * subqueries.
@@ -21,19 +29,17 @@ def qualify_tables(expression, db=None, catalog=None, schema=None):
'SELECT * FROM (SELECT * FROM tbl1 AS tbl1 JOIN tbl2 AS tbl2 ON id1 = id2) AS _q_0'
Args:
- expression (sqlglot.Expression): expression to qualify
- db (str): Database name
- catalog (str): Catalog name
+ expression: Expression to qualify
+ db: Database name
+ catalog: Catalog name
schema: A schema to populate
Returns:
- sqlglot.Expression: qualified expression
+ The qualified expression.
(*) See section 7.2.1.2 in https://www.postgresql.org/docs/current/queries-table-expressions.html
"""
- sequence = itertools.count()
-
- next_name = lambda: f"_q_{next(sequence)}"
+ next_alias_name = name_sequence("_q_")
for scope in traverse_scope(expression):
for derived_table in itertools.chain(scope.ctes, scope.derived_tables):
@@ -44,10 +50,14 @@ def qualify_tables(expression, db=None, catalog=None, schema=None):
derived_table.this.replace(exp.select("*").from_(unnested.copy(), copy=False))
if not derived_table.args.get("alias"):
- alias_ = f"_q_{next(sequence)}"
+ alias_ = next_alias_name()
derived_table.set("alias", exp.TableAlias(this=exp.to_identifier(alias_)))
scope.rename_source(None, alias_)
+ pivots = derived_table.args.get("pivots")
+ if pivots and not pivots[0].alias:
+ pivots[0].set("alias", exp.TableAlias(this=exp.to_identifier(next_alias_name())))
+
for name, source in scope.sources.items():
if isinstance(source, exp.Table):
if isinstance(source.this, exp.Identifier):
@@ -59,12 +69,19 @@ def qualify_tables(expression, db=None, catalog=None, schema=None):
if not source.alias:
source = source.replace(
alias(
- source.copy(),
- name if name else next_name(),
+ source,
+ name or source.name or next_alias_name(),
+ copy=True,
table=True,
)
)
+ pivots = source.args.get("pivots")
+ if pivots and not pivots[0].alias:
+ pivots[0].set(
+ "alias", exp.TableAlias(this=exp.to_identifier(next_alias_name()))
+ )
+
if schema and isinstance(source.this, exp.ReadCSV):
with csv_reader(source.this) as reader:
header = next(reader)
@@ -74,11 +91,11 @@ def qualify_tables(expression, db=None, catalog=None, schema=None):
)
elif isinstance(source, Scope) and source.is_udtf:
udtf = source.expression
- table_alias = udtf.args.get("alias") or exp.TableAlias(this=next_name())
+ table_alias = udtf.args.get("alias") or exp.TableAlias(this=next_alias_name())
udtf.set("alias", table_alias)
if not table_alias.name:
- table_alias.set("this", next_name())
+ table_alias.set("this", next_alias_name())
if isinstance(udtf, exp.Values) and not table_alias.columns:
for i, e in enumerate(udtf.expressions[0].expressions):
table_alias.append("columns", exp.to_identifier(f"_col_{i}"))
diff --git a/sqlglot/optimizer/scope.py b/sqlglot/optimizer/scope.py
index e00b3c9..9ffb4d6 100644
--- a/sqlglot/optimizer/scope.py
+++ b/sqlglot/optimizer/scope.py
@@ -1,4 +1,5 @@
import itertools
+import typing as t
from collections import defaultdict
from enum import Enum, auto
@@ -83,6 +84,7 @@ class Scope:
self._columns = None
self._external_columns = None
self._join_hints = None
+ self._pivots = None
def branch(self, expression, scope_type, chain_sources=None, **kwargs):
"""Branch from the current scope to a new, inner scope"""
@@ -261,12 +263,14 @@ class Scope:
self._columns = []
for column in columns + external_columns:
- ancestor = column.find_ancestor(exp.Qualify, exp.Order, exp.Having, exp.Hint)
+ ancestor = column.find_ancestor(
+ exp.Select, exp.Qualify, exp.Order, exp.Having, exp.Hint
+ )
if (
not ancestor
- # Window functions can have an ORDER BY clause
- or not isinstance(ancestor.parent, exp.Select)
or column.table
+ or isinstance(ancestor, exp.Select)
+ or (isinstance(ancestor, exp.Order) and isinstance(ancestor.parent, exp.Window))
or (column.name not in named_selects and not isinstance(ancestor, exp.Hint))
):
self._columns.append(column)
@@ -370,6 +374,17 @@ class Scope:
return []
return self._join_hints
+ @property
+ def pivots(self):
+ if not self._pivots:
+ self._pivots = [
+ pivot
+ for node in self.tables + self.derived_tables
+ for pivot in node.args.get("pivots") or []
+ ]
+
+ return self._pivots
+
def source_columns(self, source_name):
"""
Get all columns in the current scope for a particular source.
@@ -463,7 +478,7 @@ class Scope:
return scope_ref_count
-def traverse_scope(expression):
+def traverse_scope(expression: exp.Expression) -> t.List[Scope]:
"""
Traverse an expression by it's "scopes".
@@ -488,10 +503,12 @@ def traverse_scope(expression):
Returns:
list[Scope]: scope instances
"""
+ if not isinstance(expression, exp.Unionable):
+ return []
return list(_traverse_scope(Scope(expression)))
-def build_scope(expression):
+def build_scope(expression: exp.Expression) -> t.Optional[Scope]:
"""
Build a scope tree.
@@ -500,7 +517,10 @@ def build_scope(expression):
Returns:
Scope: root scope
"""
- return traverse_scope(expression)[-1]
+ scopes = traverse_scope(expression)
+ if scopes:
+ return scopes[-1]
+ return None
def _traverse_scope(scope):
@@ -585,7 +605,7 @@ def _traverse_tables(scope):
expressions = []
from_ = scope.expression.args.get("from")
if from_:
- expressions.extend(from_.expressions)
+ expressions.append(from_.this)
for join in scope.expression.args.get("joins") or []:
expressions.append(join.this)
@@ -601,8 +621,13 @@ def _traverse_tables(scope):
source_name = expression.alias_or_name
if table_name in scope.sources:
- # This is a reference to a parent source (e.g. a CTE), not an actual table.
- sources[source_name] = scope.sources[table_name]
+ # This is a reference to a parent source (e.g. a CTE), not an actual table, unless
+ # it is pivoted, because then we get back a new table and hence a new source.
+ pivots = expression.args.get("pivots")
+ if pivots:
+ sources[pivots[0].alias] = expression
+ else:
+ sources[source_name] = scope.sources[table_name]
elif source_name in sources:
sources[find_new_name(sources, table_name)] = expression
else:
diff --git a/sqlglot/optimizer/simplify.py b/sqlglot/optimizer/simplify.py
index 0904189..e2772a0 100644
--- a/sqlglot/optimizer/simplify.py
+++ b/sqlglot/optimizer/simplify.py
@@ -5,11 +5,9 @@ from collections import deque
from decimal import Decimal
from sqlglot import exp
-from sqlglot.generator import Generator
+from sqlglot.generator import cached_generator
from sqlglot.helper import first, while_changing
-GENERATOR = Generator(normalize=True, identify="safe")
-
def simplify(expression):
"""
@@ -27,12 +25,12 @@ def simplify(expression):
sqlglot.Expression: simplified expression
"""
- cache = {}
+ generate = cached_generator()
def _simplify(expression, root=True):
node = expression
node = rewrite_between(node)
- node = uniq_sort(node, cache, root)
+ node = uniq_sort(node, generate, root)
node = absorb_and_eliminate(node, root)
exp.replace_children(node, lambda e: _simplify(e, False))
node = simplify_not(node)
@@ -247,7 +245,7 @@ def remove_compliments(expression, root=True):
return expression
-def uniq_sort(expression, cache=None, root=True):
+def uniq_sort(expression, generate, root=True):
"""
Uniq and sort a connector.
@@ -256,7 +254,7 @@ def uniq_sort(expression, cache=None, root=True):
if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
result_func = exp.and_ if isinstance(expression, exp.And) else exp.or_
flattened = tuple(expression.flatten())
- deduped = {GENERATOR.generate(e, cache): e for e in flattened}
+ deduped = {generate(e): e for e in flattened}
arr = tuple(deduped.items())
# check if the operands are already sorted, if not sort them
@@ -388,14 +386,18 @@ def _simplify_binary(expression, a, b):
def simplify_parens(expression):
- if (
- isinstance(expression, exp.Paren)
- and not isinstance(expression.this, exp.Select)
- and (
- not isinstance(expression.parent, (exp.Condition, exp.Binary))
- or isinstance(expression.this, exp.Predicate)
- or not isinstance(expression.this, exp.Binary)
- )
+ if not isinstance(expression, exp.Paren):
+ return expression
+
+ this = expression.this
+ parent = expression.parent
+
+ if not isinstance(this, exp.Select) and (
+ not isinstance(parent, (exp.Condition, exp.Binary))
+ or isinstance(this, exp.Predicate)
+ or not isinstance(this, exp.Binary)
+ or (isinstance(this, exp.Add) and isinstance(parent, exp.Add))
+ or (isinstance(this, exp.Mul) and isinstance(parent, exp.Mul))
):
return expression.this
return expression
diff --git a/sqlglot/optimizer/unnest_subqueries.py b/sqlglot/optimizer/unnest_subqueries.py
index a515489..09e3f2a 100644
--- a/sqlglot/optimizer/unnest_subqueries.py
+++ b/sqlglot/optimizer/unnest_subqueries.py
@@ -1,6 +1,5 @@
-import itertools
-
from sqlglot import exp
+from sqlglot.helper import name_sequence
from sqlglot.optimizer.scope import ScopeType, traverse_scope
@@ -22,7 +21,7 @@ def unnest_subqueries(expression):
Returns:
sqlglot.Expression: unnested expression
"""
- sequence = itertools.count()
+ next_alias_name = name_sequence("_u_")
for scope in traverse_scope(expression):
select = scope.expression
@@ -30,19 +29,19 @@ def unnest_subqueries(expression):
if not parent:
continue
if scope.external_columns:
- decorrelate(select, parent, scope.external_columns, sequence)
+ decorrelate(select, parent, scope.external_columns, next_alias_name)
elif scope.scope_type == ScopeType.SUBQUERY:
- unnest(select, parent, sequence)
+ unnest(select, parent, next_alias_name)
return expression
-def unnest(select, parent_select, sequence):
+def unnest(select, parent_select, next_alias_name):
if len(select.selects) > 1:
return
predicate = select.find_ancestor(exp.Condition)
- alias = _alias(sequence)
+ alias = next_alias_name()
if not predicate or parent_select is not predicate.parent_select:
return
@@ -87,13 +86,13 @@ def unnest(select, parent_select, sequence):
)
-def decorrelate(select, parent_select, external_columns, sequence):
+def decorrelate(select, parent_select, external_columns, next_alias_name):
where = select.args.get("where")
if not where or where.find(exp.Or) or select.find(exp.Limit, exp.Offset):
return
- table_alias = _alias(sequence)
+ table_alias = next_alias_name()
keys = []
# for all external columns in the where statement, find the relevant predicate
@@ -136,7 +135,7 @@ def decorrelate(select, parent_select, external_columns, sequence):
group_by.append(key)
else:
if key not in key_aliases:
- key_aliases[key] = _alias(sequence)
+ key_aliases[key] = next_alias_name()
# all predicates that are equalities must also be in the unique
# so that we don't do a many to many join
if isinstance(predicate, exp.EQ) and key not in group_by:
@@ -244,10 +243,6 @@ def decorrelate(select, parent_select, external_columns, sequence):
)
-def _alias(sequence):
- return f"_u_{next(sequence)}"
-
-
def _replace(expression, condition):
return expression.replace(exp.condition(condition))