summaryrefslogtreecommitdiffstats
path: root/sqlglot/optimizer
diff options
context:
space:
mode:
authorDaniel Baumann <daniel.baumann@progress-linux.org>2023-05-10 06:44:58 +0000
committerDaniel Baumann <daniel.baumann@progress-linux.org>2023-05-10 06:44:58 +0000
commitbeba715b97dd2349e01dde9b077d2535680ebdca (patch)
tree0c54accb48c28eb54d2f48f88d149492717b30e5 /sqlglot/optimizer
parentReleasing debian version 11.7.1-1. (diff)
downloadsqlglot-beba715b97dd2349e01dde9b077d2535680ebdca.tar.xz
sqlglot-beba715b97dd2349e01dde9b077d2535680ebdca.zip
Merging upstream version 12.2.0.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot/optimizer')
-rw-r--r--sqlglot/optimizer/eliminate_joins.py2
-rw-r--r--sqlglot/optimizer/expand_laterals.py4
-rw-r--r--sqlglot/optimizer/normalize.py2
-rw-r--r--sqlglot/optimizer/optimizer.py2
-rw-r--r--sqlglot/optimizer/qualify_columns.py26
-rw-r--r--sqlglot/optimizer/qualify_tables.py18
-rw-r--r--sqlglot/optimizer/scope.py6
-rw-r--r--sqlglot/optimizer/simplify.py17
8 files changed, 59 insertions, 18 deletions
diff --git a/sqlglot/optimizer/eliminate_joins.py b/sqlglot/optimizer/eliminate_joins.py
index e0ddfa2..27de9c7 100644
--- a/sqlglot/optimizer/eliminate_joins.py
+++ b/sqlglot/optimizer/eliminate_joins.py
@@ -153,7 +153,7 @@ def join_condition(join):
#
# should pull y.b as the join key and x.a as the source key
if normalized(on):
- on = on if isinstance(on, exp.And) else exp.and_(on, exp.true())
+ on = on if isinstance(on, exp.And) else exp.and_(on, exp.true(), copy=False)
for condition in on.flatten():
if isinstance(condition, exp.EQ):
diff --git a/sqlglot/optimizer/expand_laterals.py b/sqlglot/optimizer/expand_laterals.py
index 59f3fec..5b2f706 100644
--- a/sqlglot/optimizer/expand_laterals.py
+++ b/sqlglot/optimizer/expand_laterals.py
@@ -29,6 +29,6 @@ def expand_laterals(expression: exp.Expression) -> exp.Expression:
for column in projection.find_all(exp.Column):
if not column.table and column.name in alias_to_expression:
column.replace(alias_to_expression[column.name].copy())
- if isinstance(projection, exp.Alias):
- alias_to_expression[projection.alias] = projection.this
+ if isinstance(projection, exp.Alias):
+ alias_to_expression[projection.alias] = projection.this
return expression
diff --git a/sqlglot/optimizer/normalize.py b/sqlglot/optimizer/normalize.py
index 40668ef..b013312 100644
--- a/sqlglot/optimizer/normalize.py
+++ b/sqlglot/optimizer/normalize.py
@@ -152,12 +152,14 @@ def _distribute(a, b, from_func, to_func, cache):
lambda c: to_func(
uniq_sort(flatten(from_func(c, b.left)), cache),
uniq_sort(flatten(from_func(c, b.right)), cache),
+ copy=False,
),
)
else:
a = to_func(
uniq_sort(flatten(from_func(a, b.left)), cache),
uniq_sort(flatten(from_func(a, b.right)), cache),
+ copy=False,
)
return a
diff --git a/sqlglot/optimizer/optimizer.py b/sqlglot/optimizer/optimizer.py
index 62eb11e..c165ffe 100644
--- a/sqlglot/optimizer/optimizer.py
+++ b/sqlglot/optimizer/optimizer.py
@@ -10,7 +10,6 @@ from sqlglot.optimizer.canonicalize import canonicalize
from sqlglot.optimizer.eliminate_ctes import eliminate_ctes
from sqlglot.optimizer.eliminate_joins import eliminate_joins
from sqlglot.optimizer.eliminate_subqueries import eliminate_subqueries
-from sqlglot.optimizer.expand_laterals import expand_laterals
from sqlglot.optimizer.expand_multi_table_selects import expand_multi_table_selects
from sqlglot.optimizer.isolate_table_selects import isolate_table_selects
from sqlglot.optimizer.lower_identities import lower_identities
@@ -30,7 +29,6 @@ RULES = (
qualify_tables,
isolate_table_selects,
qualify_columns,
- expand_laterals,
pushdown_projections,
validate_qualify_columns,
normalize,
diff --git a/sqlglot/optimizer/qualify_columns.py b/sqlglot/optimizer/qualify_columns.py
index 0a31246..6ac39f0 100644
--- a/sqlglot/optimizer/qualify_columns.py
+++ b/sqlglot/optimizer/qualify_columns.py
@@ -3,11 +3,12 @@ import typing as t
from sqlglot import alias, exp
from sqlglot.errors import OptimizeError
+from sqlglot.optimizer.expand_laterals import expand_laterals as _expand_laterals
from sqlglot.optimizer.scope import Scope, traverse_scope
from sqlglot.schema import ensure_schema
-def qualify_columns(expression, schema):
+def qualify_columns(expression, schema, expand_laterals=True):
"""
Rewrite sqlglot AST to have fully qualified columns.
@@ -26,6 +27,9 @@ def qualify_columns(expression, schema):
"""
schema = ensure_schema(schema)
+ if not schema.mapping and expand_laterals:
+ expression = _expand_laterals(expression)
+
for scope in traverse_scope(expression):
resolver = Resolver(scope, schema)
_pop_table_column_aliases(scope.ctes)
@@ -39,6 +43,9 @@ def qualify_columns(expression, schema):
_expand_group_by(scope, resolver)
_expand_order_by(scope)
+ if schema.mapping and expand_laterals:
+ expression = _expand_laterals(expression)
+
return expression
@@ -124,7 +131,7 @@ def _expand_using(scope, resolver):
tables[join_table] = None
join.args.pop("using")
- join.set("on", exp.and_(*conditions))
+ join.set("on", exp.and_(*conditions, copy=False))
if column_tables:
for column in scope.columns:
@@ -240,7 +247,9 @@ def _qualify_columns(scope, resolver):
# column_table can be a '' because bigquery unnest has no table alias
if column_table:
column.set("table", column_table)
- elif column_table not in scope.sources:
+ elif column_table not in scope.sources and (
+ not scope.parent or column_table not in scope.parent.sources
+ ):
# structs are used like tables (e.g. "struct"."field"), so they need to be qualified
# separately and represented as dot(dot(...(<table>.<column>, field1), field2, ...))
@@ -376,10 +385,13 @@ def _qualify_outputs(scope):
if not selection.output_name:
selection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{i}")))
elif not isinstance(selection, exp.Alias) and not selection.is_star:
- alias_ = alias(exp.column(""), alias=selection.output_name or f"_col_{i}")
- alias_.set("this", selection)
- selection = alias_
-
+ selection = alias(
+ selection,
+ alias=selection.output_name or f"_col_{i}",
+ quoted=True
+ if isinstance(selection, exp.Column) and selection.this.quoted
+ else None,
+ )
if aliased_column:
selection.set("alias", exp.to_identifier(aliased_column))
diff --git a/sqlglot/optimizer/qualify_tables.py b/sqlglot/optimizer/qualify_tables.py
index a719ebe..1b451a6 100644
--- a/sqlglot/optimizer/qualify_tables.py
+++ b/sqlglot/optimizer/qualify_tables.py
@@ -7,21 +7,29 @@ from sqlglot.optimizer.scope import Scope, traverse_scope
def qualify_tables(expression, db=None, catalog=None, schema=None):
"""
- Rewrite sqlglot AST to have fully qualified tables.
+ Rewrite sqlglot AST to have fully qualified tables. Additionally, this
+ replaces "join constructs" (*) by equivalent SELECT * subqueries.
- Example:
+ Examples:
>>> import sqlglot
>>> expression = sqlglot.parse_one("SELECT 1 FROM tbl")
>>> qualify_tables(expression, db="db").sql()
'SELECT 1 FROM db.tbl AS tbl'
+ >>>
+ >>> expression = sqlglot.parse_one("SELECT * FROM (tbl1 JOIN tbl2 ON id1 = id2)")
+ >>> qualify_tables(expression).sql()
+ 'SELECT * FROM (SELECT * FROM tbl1 AS tbl1 JOIN tbl2 AS tbl2 ON id1 = id2) AS _q_0'
Args:
expression (sqlglot.Expression): expression to qualify
db (str): Database name
catalog (str): Catalog name
schema: A schema to populate
+
Returns:
sqlglot.Expression: qualified expression
+
+ (*) See section 7.2.1.2 in https://www.postgresql.org/docs/current/queries-table-expressions.html
"""
sequence = itertools.count()
@@ -29,6 +37,12 @@ def qualify_tables(expression, db=None, catalog=None, schema=None):
for scope in traverse_scope(expression):
for derived_table in itertools.chain(scope.ctes, scope.derived_tables):
+ # Expand join construct
+ if isinstance(derived_table, exp.Subquery):
+ unnested = derived_table.unnest()
+ if isinstance(unnested, exp.Table):
+ derived_table.this.replace(exp.select("*").from_(unnested.copy(), copy=False))
+
if not derived_table.args.get("alias"):
alias_ = f"_q_{next(sequence)}"
derived_table.set("alias", exp.TableAlias(this=exp.to_identifier(alias_)))
diff --git a/sqlglot/optimizer/scope.py b/sqlglot/optimizer/scope.py
index b582eb0..e00b3c9 100644
--- a/sqlglot/optimizer/scope.py
+++ b/sqlglot/optimizer/scope.py
@@ -510,6 +510,9 @@ def _traverse_scope(scope):
yield from _traverse_union(scope)
elif isinstance(scope.expression, exp.Subquery):
yield from _traverse_subqueries(scope)
+ elif isinstance(scope.expression, exp.Table):
+ # This case corresponds to a "join construct", i.e. (tbl1 JOIN tbl2 ON ..)
+ yield from _traverse_tables(scope)
elif isinstance(scope.expression, exp.UDTF):
pass
else:
@@ -587,6 +590,9 @@ def _traverse_tables(scope):
for join in scope.expression.args.get("joins") or []:
expressions.append(join.this)
+ if isinstance(scope.expression, exp.Table):
+ expressions.append(scope.expression)
+
expressions.extend(scope.expression.args.get("laterals") or [])
for expression in expressions:
diff --git a/sqlglot/optimizer/simplify.py b/sqlglot/optimizer/simplify.py
index 4e6c910..0904189 100644
--- a/sqlglot/optimizer/simplify.py
+++ b/sqlglot/optimizer/simplify.py
@@ -60,6 +60,7 @@ def rewrite_between(expression: exp.Expression) -> exp.Expression:
return exp.and_(
exp.GTE(this=expression.this.copy(), expression=expression.args["low"]),
exp.LTE(this=expression.this.copy(), expression=expression.args["high"]),
+ copy=False,
)
return expression
@@ -76,9 +77,17 @@ def simplify_not(expression):
if isinstance(expression.this, exp.Paren):
condition = expression.this.unnest()
if isinstance(condition, exp.And):
- return exp.or_(exp.not_(condition.left), exp.not_(condition.right))
+ return exp.or_(
+ exp.not_(condition.left, copy=False),
+ exp.not_(condition.right, copy=False),
+ copy=False,
+ )
if isinstance(condition, exp.Or):
- return exp.and_(exp.not_(condition.left), exp.not_(condition.right))
+ return exp.and_(
+ exp.not_(condition.left, copy=False),
+ exp.not_(condition.right, copy=False),
+ copy=False,
+ )
if is_null(condition):
return exp.null()
if always_true(expression.this):
@@ -254,12 +263,12 @@ def uniq_sort(expression, cache=None, root=True):
# A AND C AND B -> A AND B AND C
for i, (sql, e) in enumerate(arr[1:]):
if sql < arr[i][0]:
- expression = result_func(*(e for _, e in sorted(arr)))
+ expression = result_func(*(e for _, e in sorted(arr)), copy=False)
break
else:
# we didn't have to sort but maybe we need to dedup
if len(deduped) < len(flattened):
- expression = result_func(*deduped.values())
+ expression = result_func(*deduped.values(), copy=False)
return expression