summary refs log tree commit diff stats
path: root/sqlglot/transforms.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/transforms.py')
-rw-r--r-- sqlglot/transforms.py 135
1 files changed, 114 insertions, 21 deletions
diff --git a/sqlglot/transforms.py b/sqlglot/transforms.py
index 62728d5..00f278e 100644
--- a/sqlglot/transforms.py
+++ b/sqlglot/transforms.py
@@ -103,7 +103,11 @@ def eliminate_qualify(expression: exp.Expression) -> exp.Expression:
if isinstance(expr, exp.Window):
alias = find_new_name(expression.named_selects, "_w")
expression.select(exp.alias_(expr.copy(), alias), copy=False)
- expr.replace(exp.column(alias))
+ column = exp.column(alias)
+ if isinstance(expr.parent, exp.Qualify):
+ qualify_filters = column
+ else:
+ expr.replace(column)
elif expr.name not in expression.named_selects:
expression.select(expr.copy(), copy=False)
@@ -133,9 +137,111 @@ def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expr
)
+def unnest_to_explode(expression: exp.Expression) -> exp.Expression:
+    """Convert cross join unnest into lateral view explode (used in presto -> hive).
+
+    For a ``Select``, every join whose source is an ``Unnest`` is removed and replaced
+    by one lateral view per (unnested expression, alias column) pair. ``Posexplode`` is
+    used when the unnest has ``ordinality`` set, ``Explode`` otherwise.
+
+    Args:
+        expression: the expression to transform; only ``exp.Select`` nodes are modified.
+
+    Returns:
+        The expression object that was passed in.
+    """
+    if isinstance(expression, exp.Select):
+        joins = expression.args.get("joins") or []
+
+        # Iterate over a snapshot: joins are removed from the live list below, and
+        # removing from a list while iterating over it skips the following element
+        # (which would leave the second of two consecutive UNNEST joins unconverted).
+        for join in list(joins):
+            unnest = join.this
+
+            if isinstance(unnest, exp.Unnest):
+                alias = unnest.args.get("alias")
+                udtf = exp.Posexplode if unnest.args.get("ordinality") else exp.Explode
+
+                joins.remove(join)
+
+                # NOTE(review): an unaliased unnest pairs with no columns, so its join is
+                # removed without any lateral being added - confirm this is intended.
+                for e, column in zip(unnest.expressions, alias.columns if alias else []):
+                    expression.append(
+                        "laterals",
+                        exp.Lateral(
+                            this=udtf(this=e),
+                            view=True,
+                            alias=exp.TableAlias(this=alias.this, columns=[column]),  # type: ignore
+                        ),
+                    )
+
+    return expression
+
+
+def explode_to_unnest(expression: exp.Expression) -> exp.Expression:
+ """Convert explode/posexplode into unnest (used in hive -> presto).
+
+ For a Select, each projection that is an Explode/Posexplode call (bare, or wrapped
+ in an Alias/Aliases node) is rewritten to plain column reference(s), and an aliased
+ Unnest over a copy of the call's argument is added as the FROM source, or as a
+ CROSS join when a FROM already exists.
+
+ Args:
+ expression: the expression to transform.
+
+ Returns:
+ The expression that was passed in.
+ """
+ if isinstance(expression, exp.Select):
+ # NOTE(review): imported inside the function rather than at module level -
+ # presumably to avoid a circular import; confirm.
+ from sqlglot.optimizer.scope import build_scope
+
+ # Names that are already in use; generated aliases are picked so as not to
+ # collide with them, and every generated name is added back to these sets.
+ taken_select_names = set(expression.named_selects)
+ taken_source_names = set(build_scope(expression).selected_sources)
+
+ for select in expression.selects:
+ # Keep a handle on the original projection: `select` may be rebound to the
+ # wrapped expression below, but it is the wrapper that gets replaced/popped.
+ to_replace = select
+
+ pos_alias = ""
+ explode_alias = ""
+
+ # Unwrap aliased projections, remembering the alias name(s) that were given.
+ if isinstance(select, exp.Alias):
+ explode_alias = select.alias
+ select = select.this
+ elif isinstance(select, exp.Aliases):
+ # The first alias names the position column, the second the exploded value.
+ pos_alias = select.aliases[0].name
+ explode_alias = select.aliases[1].name
+ select = select.this
+
+ if isinstance(select, (exp.Explode, exp.Posexplode)):
+ is_posexplode = isinstance(select, exp.Posexplode)
+
+ explode_arg = select.this
+ unnest = exp.Unnest(expressions=[explode_arg.copy()], ordinality=is_posexplode)
+
+ # This ensures that we won't use [POS]EXPLODE's argument as a new selection
+ if isinstance(explode_arg, exp.Column):
+ taken_select_names.add(explode_arg.output_name)
+
+ # Fresh table alias for the UNNEST source, derived from "_u".
+ unnest_source_alias = find_new_name(taken_source_names, "_u")
+ taken_source_names.add(unnest_source_alias)
+
+ # Fall back to generated column names when no alias was supplied.
+ if not explode_alias:
+ explode_alias = find_new_name(taken_select_names, "col")
+ taken_select_names.add(explode_alias)
+
+ if is_posexplode:
+ pos_alias = find_new_name(taken_select_names, "pos")
+ taken_select_names.add(pos_alias)
+
+ # POSEXPLODE produces two output columns, so the original projection is popped
+ # and both names are selected; EXPLODE is swapped for a single column reference.
+ if is_posexplode:
+ column_names = [explode_alias, pos_alias]
+ to_replace.pop()
+ expression.select(pos_alias, explode_alias, copy=False)
+ else:
+ column_names = [explode_alias]
+ to_replace.replace(exp.column(explode_alias))
+
+ # Alias the UNNEST as a table exposing the column names chosen above.
+ unnest = exp.alias_(unnest, unnest_source_alias, table=column_names)
+
+ # Use the UNNEST as the FROM source if there is none yet; otherwise CROSS join it.
+ if not expression.args.get("from"):
+ expression.from_(unnest, copy=False)
+ else:
+ expression.join(unnest, join_type="CROSS", copy=False)
+
+ return expression
+
+
+def remove_target_from_merge(expression: exp.Expression) -> exp.Expression:
+    """Strip the merge target's table qualifier from columns inside WHEN clauses.
+
+    Columns whose ``table`` arg is the merge target (or the target's alias) are
+    replaced by unqualified columns of the same name. Expressions other than
+    ``exp.Merge`` are returned as they are.
+    """
+    if not isinstance(expression, exp.Merge):
+        return expression
+
+    target = expression.this
+    target_alias = target.args.get("alias")
+
+    # Everything a column may use to refer to the merge target.
+    targets = {target.this}
+    if target_alias:
+        targets.add(target_alias.this)
+
+    def _strip_target(node):
+        # Only columns qualified with the target (or its alias) are rewritten.
+        if isinstance(node, exp.Column) and node.args.get("table") in targets:
+            return exp.column(node.name)
+        return node
+
+    for when in expression.expressions:
+        when.transform(_strip_target, copy=False)
+
+    return expression
+
+
def preprocess(
transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
- to_sql: t.Callable[[Generator, exp.Expression], str],
) -> t.Callable[[Generator, exp.Expression], str]:
"""
Creates a new transform by chaining a sequence of transformations and converts the resulting
@@ -143,36 +249,23 @@ def preprocess(
Args:
transforms: sequence of transform functions. These will be called in order.
- to_sql: final transform that converts the resulting expression to a SQL string.
Returns:
Function that can be used as a generator transform.
"""
- def _to_sql(self, expression):
+ def _to_sql(self, expression: exp.Expression) -> str:
expression = transforms[0](expression.copy())
for t in transforms[1:]:
expression = t(expression)
- return to_sql(self, expression)
+ return getattr(self, expression.key + "_sql")(expression)
return _to_sql
-def delegate(attr: str) -> t.Callable:
- """
- Create a new method that delegates to `attr`. This is useful for creating `Generator.TRANSFORMS`
- functions that delegate to existing generator methods.
- """
-
- def _transform(self, *args, **kwargs):
- return getattr(self, attr)(*args, **kwargs)
-
- return _transform
-
-
-UNALIAS_GROUP = {exp.Group: preprocess([unalias_group], delegate("group_sql"))}
-ELIMINATE_DISTINCT_ON = {exp.Select: preprocess([eliminate_distinct_on], delegate("select_sql"))}
-ELIMINATE_QUALIFY = {exp.Select: preprocess([eliminate_qualify], delegate("select_sql"))}
+UNALIAS_GROUP = {exp.Group: preprocess([unalias_group])}
+ELIMINATE_DISTINCT_ON = {exp.Select: preprocess([eliminate_distinct_on])}
+ELIMINATE_QUALIFY = {exp.Select: preprocess([eliminate_qualify])}
REMOVE_PRECISION_PARAMETERIZED_TYPES = {
- exp.Cast: preprocess([remove_precision_parameterized_types], delegate("cast_sql"))
+ exp.Cast: preprocess([remove_precision_parameterized_types])
}