summaryrefslogtreecommitdiffstats
path: root/sqlglot/transforms.py
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--sqlglot/transforms.py62
1 files changed, 49 insertions, 13 deletions
diff --git a/sqlglot/transforms.py b/sqlglot/transforms.py
index 445fda6..03acc2b 100644
--- a/sqlglot/transforms.py
+++ b/sqlglot/transforms.py
@@ -141,7 +141,7 @@ def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expr
def unnest_to_explode(expression: exp.Expression) -> exp.Expression:
- """Convert cross join unnest into lateral view explode (used in presto -> hive)."""
+ """Convert cross join unnest into lateral view explode."""
if isinstance(expression, exp.Select):
for join in expression.args.get("joins") or []:
unnest = join.this
@@ -166,7 +166,7 @@ def unnest_to_explode(expression: exp.Expression) -> exp.Expression:
def explode_to_unnest(index_offset: int = 0) -> t.Callable[[exp.Expression], exp.Expression]:
- """Convert explode/posexplode into unnest (used in hive -> presto)."""
+ """Convert explode/posexplode into unnest."""
def _explode_to_unnest(expression: exp.Expression) -> exp.Expression:
if isinstance(expression, exp.Select):
@@ -199,11 +199,11 @@ def explode_to_unnest(index_offset: int = 0) -> t.Callable[[exp.Expression], exp
explode_alias = ""
if isinstance(select, exp.Alias):
- explode_alias = select.alias
+ explode_alias = select.args["alias"]
alias = select
elif isinstance(select, exp.Aliases):
- pos_alias = select.aliases[0].name
- explode_alias = select.aliases[1].name
+ pos_alias = select.aliases[0]
+ explode_alias = select.aliases[1]
alias = select.replace(exp.alias_(select.this, "", copy=False))
else:
alias = select.replace(exp.alias_(select, ""))
@@ -230,9 +230,12 @@ def explode_to_unnest(index_offset: int = 0) -> t.Callable[[exp.Expression], exp
alias.set("alias", exp.to_identifier(explode_alias))
+ series_table_alias = series.args["alias"].this
column = exp.If(
- this=exp.column(series_alias).eq(exp.column(pos_alias)),
- true=exp.column(explode_alias),
+ this=exp.column(series_alias, table=series_table_alias).eq(
+ exp.column(pos_alias, table=unnest_source_alias)
+ ),
+ true=exp.column(explode_alias, table=unnest_source_alias),
)
explode.replace(column)
@@ -242,8 +245,10 @@ def explode_to_unnest(index_offset: int = 0) -> t.Callable[[exp.Expression], exp
expressions.insert(
expressions.index(alias) + 1,
exp.If(
- this=exp.column(series_alias).eq(exp.column(pos_alias)),
- true=exp.column(pos_alias),
+ this=exp.column(series_alias, table=series_table_alias).eq(
+ exp.column(pos_alias, table=unnest_source_alias)
+ ),
+ true=exp.column(pos_alias, table=unnest_source_alias),
).as_(pos_alias),
)
expression.set("expressions", expressions)
@@ -276,10 +281,12 @@ def explode_to_unnest(index_offset: int = 0) -> t.Callable[[exp.Expression], exp
size = size - 1
expression.where(
- exp.column(series_alias)
- .eq(exp.column(pos_alias))
+ exp.column(series_alias, table=series_table_alias)
+ .eq(exp.column(pos_alias, table=unnest_source_alias))
.or_(
- (exp.column(series_alias) > size).and_(exp.column(pos_alias).eq(size))
+ (exp.column(series_alias, table=series_table_alias) > size).and_(
+ exp.column(pos_alias, table=unnest_source_alias).eq(size)
+ )
),
copy=False,
)
@@ -386,14 +393,16 @@ def eliminate_full_outer_join(expression: exp.Expression) -> exp.Expression:
full_outer_joins = [
(index, join)
for index, join in enumerate(expression.args.get("joins") or [])
- if join.side == "FULL" and join.kind == "OUTER"
+ if join.side == "FULL"
]
if len(full_outer_joins) == 1:
expression_copy = expression.copy()
+ expression.set("limit", None)
index, full_outer_join = full_outer_joins[0]
full_outer_join.set("side", "left")
expression_copy.args["joins"][index].set("side", "right")
+ expression_copy.args.pop("with", None) # remove CTEs from RIGHT side
return exp.union(expression, expression_copy, copy=False)
@@ -430,6 +439,33 @@ def move_ctes_to_top_level(expression: exp.Expression) -> exp.Expression:
return expression
+def ensure_bools(expression: exp.Expression) -> exp.Expression:
+ """Converts numeric values used in conditions into explicit boolean expressions."""
+ from sqlglot.optimizer.canonicalize import ensure_bools
+
+ def _ensure_bool(node: exp.Expression) -> None:
+ if (
+ node.is_number
+ or node.is_type(exp.DataType.Type.UNKNOWN, *exp.DataType.NUMERIC_TYPES)
+ or (isinstance(node, exp.Column) and not node.type)
+ ):
+ node.replace(node.neq(0))
+
+ for node, *_ in expression.walk():
+ ensure_bools(node, _ensure_bool)
+
+ return expression
+
+
+def unqualify_columns(expression: exp.Expression) -> exp.Expression:
+ for column in expression.find_all(exp.Column):
+ # We only wanna pop off the table, db, catalog args
+ for part in column.parts[:-1]:
+ part.pop()
+
+ return expression
+
+
def preprocess(
transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
) -> t.Callable[[Generator, exp.Expression], str]: