summaryrefslogtreecommitdiffstats
path: root/sqlglot/optimizer
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/optimizer')
-rw-r--r--sqlglot/optimizer/canonicalize.py28
-rw-r--r--sqlglot/optimizer/pushdown_projections.py17
-rw-r--r--sqlglot/optimizer/qualify_columns.py22
-rw-r--r--sqlglot/optimizer/qualify_tables.py8
-rw-r--r--sqlglot/optimizer/scope.py5
5 files changed, 62 insertions, 18 deletions
diff --git a/sqlglot/optimizer/canonicalize.py b/sqlglot/optimizer/canonicalize.py
index fc37a54..c5c780d 100644
--- a/sqlglot/optimizer/canonicalize.py
+++ b/sqlglot/optimizer/canonicalize.py
@@ -1,9 +1,12 @@
+from __future__ import annotations
+
import itertools
from sqlglot import exp
+from sqlglot.helper import should_identify
-def canonicalize(expression: exp.Expression) -> exp.Expression:
+def canonicalize(expression: exp.Expression, identify: str = "safe") -> exp.Expression:
"""Converts a sql expression into a standard form.
This method relies on annotate_types because many of the
@@ -11,15 +14,18 @@ def canonicalize(expression: exp.Expression) -> exp.Expression:
Args:
expression: The expression to canonicalize.
+ identify: Whether or not to force identify identifier.
"""
- exp.replace_children(expression, canonicalize)
+ exp.replace_children(expression, canonicalize, identify=identify)
expression = add_text_to_concat(expression)
expression = coerce_type(expression)
expression = remove_redundant_casts(expression)
+ expression = ensure_bool_predicates(expression)
if isinstance(expression, exp.Identifier):
- expression.set("quoted", True)
+ if should_identify(expression.this, identify):
+ expression.set("quoted", True)
return expression
@@ -52,6 +58,17 @@ def remove_redundant_casts(expression: exp.Expression) -> exp.Expression:
return expression
+def ensure_bool_predicates(expression: exp.Expression) -> exp.Expression:
+ if isinstance(expression, exp.Connector):
+ _replace_int_predicate(expression.left)
+ _replace_int_predicate(expression.right)
+
+ elif isinstance(expression, (exp.Where, exp.Having)):
+ _replace_int_predicate(expression.this)
+
+ return expression
+
+
def _coerce_date(a: exp.Expression, b: exp.Expression) -> None:
for a, b in itertools.permutations([a, b]):
if (
@@ -68,3 +85,8 @@ def _replace_cast(node: exp.Expression, to: str) -> None:
cast = exp.Cast(this=node.copy(), to=data_type)
cast.type = data_type
node.replace(cast)
+
+
+def _replace_int_predicate(expression: exp.Expression) -> None:
+ if expression.type and expression.type.this in exp.DataType.INTEGER_TYPES:
+ expression.replace(exp.NEQ(this=expression.copy(), expression=exp.Literal.number(0)))
diff --git a/sqlglot/optimizer/pushdown_projections.py b/sqlglot/optimizer/pushdown_projections.py
index 07a1b70..2e51117 100644
--- a/sqlglot/optimizer/pushdown_projections.py
+++ b/sqlglot/optimizer/pushdown_projections.py
@@ -1,7 +1,6 @@
from collections import defaultdict
from sqlglot import alias, exp
-from sqlglot.helper import flatten
from sqlglot.optimizer.qualify_columns import Resolver
from sqlglot.optimizer.scope import Scope, traverse_scope
from sqlglot.schema import ensure_schema
@@ -86,14 +85,15 @@ def _remove_unused_selections(scope, parent_selections, schema):
else:
order_refs = set()
- new_selections = defaultdict(list)
+ new_selections = []
removed = False
star = False
+
for selection in scope.selects:
name = selection.alias_or_name
if SELECT_ALL in parent_selections or name in parent_selections or name in order_refs:
- new_selections[name].append(selection)
+ new_selections.append(selection)
else:
if selection.is_star:
star = True
@@ -101,18 +101,17 @@ def _remove_unused_selections(scope, parent_selections, schema):
if star:
resolver = Resolver(scope, schema)
+ names = {s.alias_or_name for s in new_selections}
for name in sorted(parent_selections):
- if name not in new_selections:
- new_selections[name].append(
- alias(exp.column(name, table=resolver.get_table(name)), name)
- )
+ if name not in names:
+ new_selections.append(alias(exp.column(name, table=resolver.get_table(name)), name))
# If there are no remaining selections, just select a single constant
if not new_selections:
- new_selections[""].append(DEFAULT_SELECTION())
+ new_selections.append(DEFAULT_SELECTION())
- scope.expression.select(*flatten(new_selections.values()), append=False, copy=False)
+ scope.expression.select(*new_selections, append=False, copy=False)
if removed:
scope.clear_cache()
diff --git a/sqlglot/optimizer/qualify_columns.py b/sqlglot/optimizer/qualify_columns.py
index e793e31..66b3170 100644
--- a/sqlglot/optimizer/qualify_columns.py
+++ b/sqlglot/optimizer/qualify_columns.py
@@ -37,6 +37,7 @@ def qualify_columns(expression, schema):
_qualify_outputs(scope)
_expand_group_by(scope, resolver)
_expand_order_by(scope)
+
return expression
@@ -213,6 +214,21 @@ def _qualify_columns(scope, resolver):
# column_table can be a '' because bigquery unnest has no table alias
if column_table:
column.set("table", column_table)
+ elif column_table not in scope.sources:
+ # structs are used like tables (e.g. "struct"."field"), so they need to be qualified
+ # separately and represented as dot(dot(...(<table>.<column>, field1), field2, ...))
+
+ root, *parts = column.parts
+
+ if root.name in scope.sources:
+ # struct is already qualified, but we still need to change the AST representation
+ column_table = root
+ root, *parts = parts
+ else:
+ column_table = resolver.get_table(root.name)
+
+ if column_table:
+ column.replace(exp.Dot.build([exp.column(root, table=column_table), *parts]))
columns_missing_from_scope = []
# Determine whether each reference in the order by clause is to a column or an alias.
@@ -373,10 +389,14 @@ class Resolver:
if isinstance(node, exp.Subqueryable):
while node and node.alias != table_name:
node = node.parent
+
node_alias = node.args.get("alias")
if node_alias:
return node_alias.this
- return exp.to_identifier(table_name)
+
+ return exp.to_identifier(
+ table_name, quoted=node.this.quoted if isinstance(node, exp.Table) else None
+ )
@property
def all_columns(self):
diff --git a/sqlglot/optimizer/qualify_tables.py b/sqlglot/optimizer/qualify_tables.py
index 6e50182..93e1179 100644
--- a/sqlglot/optimizer/qualify_tables.py
+++ b/sqlglot/optimizer/qualify_tables.py
@@ -34,11 +34,9 @@ def qualify_tables(expression, db=None, catalog=None, schema=None):
derived_table.set("alias", exp.TableAlias(this=exp.to_identifier(alias_)))
scope.rename_source(None, alias_)
- for source in scope.sources.values():
+ for name, source in scope.sources.items():
if isinstance(source, exp.Table):
- identifier = isinstance(source.this, exp.Identifier)
-
- if identifier:
+ if isinstance(source.this, exp.Identifier):
if not source.args.get("db"):
source.set("db", exp.to_identifier(db))
if not source.args.get("catalog"):
@@ -48,7 +46,7 @@ def qualify_tables(expression, db=None, catalog=None, schema=None):
source = source.replace(
alias(
source.copy(),
- source.this if identifier else next_name(),
+ name if name else next_name(),
table=True,
)
)
diff --git a/sqlglot/optimizer/scope.py b/sqlglot/optimizer/scope.py
index 335ff3e..9c0768c 100644
--- a/sqlglot/optimizer/scope.py
+++ b/sqlglot/optimizer/scope.py
@@ -4,6 +4,7 @@ from enum import Enum, auto
from sqlglot import exp
from sqlglot.errors import OptimizeError
+from sqlglot.helper import find_new_name
class ScopeType(Enum):
@@ -293,6 +294,8 @@ class Scope:
result = {}
for name, node in referenced_names:
+ if name in result:
+ raise OptimizeError(f"Alias already used: {name}")
if name in self.sources:
result[name] = (node, self.sources[name])
@@ -594,6 +597,8 @@ def _traverse_tables(scope):
if table_name in scope.sources:
# This is a reference to a parent source (e.g. a CTE), not an actual table.
sources[source_name] = scope.sources[table_name]
+ elif source_name in sources:
+ sources[find_new_name(sources, table_name)] = expression
else:
sources[source_name] = expression
continue