diff options
Diffstat (limited to 'sqlglot/transforms.py')
-rw-r--r-- | sqlglot/transforms.py | 135 |
1 files changed, 114 insertions, 21 deletions
diff --git a/sqlglot/transforms.py b/sqlglot/transforms.py index 62728d5..00f278e 100644 --- a/sqlglot/transforms.py +++ b/sqlglot/transforms.py @@ -103,7 +103,11 @@ def eliminate_qualify(expression: exp.Expression) -> exp.Expression: if isinstance(expr, exp.Window): alias = find_new_name(expression.named_selects, "_w") expression.select(exp.alias_(expr.copy(), alias), copy=False) - expr.replace(exp.column(alias)) + column = exp.column(alias) + if isinstance(expr.parent, exp.Qualify): + qualify_filters = column + else: + expr.replace(column) elif expr.name not in expression.named_selects: expression.select(expr.copy(), copy=False) @@ -133,9 +137,111 @@ def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expr ) +def unnest_to_explode(expression: exp.Expression) -> exp.Expression: + """Convert cross join unnest into lateral view explode (used in presto -> hive).""" + if isinstance(expression, exp.Select): + for join in expression.args.get("joins") or []: + unnest = join.this + + if isinstance(unnest, exp.Unnest): + alias = unnest.args.get("alias") + udtf = exp.Posexplode if unnest.args.get("ordinality") else exp.Explode + + expression.args["joins"].remove(join) + + for e, column in zip(unnest.expressions, alias.columns if alias else []): + expression.append( + "laterals", + exp.Lateral( + this=udtf(this=e), + view=True, + alias=exp.TableAlias(this=alias.this, columns=[column]), # type: ignore + ), + ) + return expression + + +def explode_to_unnest(expression: exp.Expression) -> exp.Expression: + """Convert explode/posexplode into unnest (used in hive -> presto).""" + if isinstance(expression, exp.Select): + from sqlglot.optimizer.scope import build_scope + + taken_select_names = set(expression.named_selects) + taken_source_names = set(build_scope(expression).selected_sources) + + for select in expression.selects: + to_replace = select + + pos_alias = "" + explode_alias = "" + + if isinstance(select, exp.Alias): + explode_alias = select.alias + select = select.this + elif isinstance(select, exp.Aliases): + pos_alias = select.aliases[0].name + explode_alias = select.aliases[1].name + select = select.this + + if isinstance(select, (exp.Explode, exp.Posexplode)): + is_posexplode = isinstance(select, exp.Posexplode) + + explode_arg = select.this + unnest = exp.Unnest(expressions=[explode_arg.copy()], ordinality=is_posexplode) + + # This ensures that we won't use [POS]EXPLODE's argument as a new selection + if isinstance(explode_arg, exp.Column): + taken_select_names.add(explode_arg.output_name) + + unnest_source_alias = find_new_name(taken_source_names, "_u") + taken_source_names.add(unnest_source_alias) + + if not explode_alias: + explode_alias = find_new_name(taken_select_names, "col") + taken_select_names.add(explode_alias) + + if is_posexplode: + pos_alias = find_new_name(taken_select_names, "pos") + taken_select_names.add(pos_alias) + + if is_posexplode: + column_names = [explode_alias, pos_alias] + to_replace.pop() + expression.select(pos_alias, explode_alias, copy=False) + else: + column_names = [explode_alias] + to_replace.replace(exp.column(explode_alias)) + + unnest = exp.alias_(unnest, unnest_source_alias, table=column_names) + + if not expression.args.get("from"): + expression.from_(unnest, copy=False) + else: + expression.join(unnest, join_type="CROSS", copy=False) + + return expression + + +def remove_target_from_merge(expression: exp.Expression) -> exp.Expression: + """Remove table refs from columns in when statements.""" + if isinstance(expression, exp.Merge): + alias = expression.this.args.get("alias") + targets = {expression.this.this} + if alias: + targets.add(alias.this) + + for when in expression.expressions: + when.transform( + lambda node: exp.column(node.name) + if isinstance(node, exp.Column) and node.args.get("table") in targets + else node, + copy=False, + ) + return expression + + def preprocess( transforms: t.List[t.Callable[[exp.Expression], exp.Expression]], - to_sql: t.Callable[[Generator, exp.Expression], str], ) -> t.Callable[[Generator, exp.Expression], str]: """ Creates a new transform by chaining a sequence of transformations and converts the resulting @@ -143,36 +249,23 @@ def preprocess( Args: transforms: sequence of transform functions. These will be called in order. - to_sql: final transform that converts the resulting expression to a SQL string. Returns: Function that can be used as a generator transform. """ - def _to_sql(self, expression): + def _to_sql(self, expression: exp.Expression) -> str: expression = transforms[0](expression.copy()) for t in transforms[1:]: expression = t(expression) - return to_sql(self, expression) + return getattr(self, expression.key + "_sql")(expression) return _to_sql -def delegate(attr: str) -> t.Callable: - """ - Create a new method that delegates to `attr`. This is useful for creating `Generator.TRANSFORMS` - functions that delegate to existing generator methods. - """ - - def _transform(self, *args, **kwargs): - return getattr(self, attr)(*args, **kwargs) - - return _transform - - -UNALIAS_GROUP = {exp.Group: preprocess([unalias_group], delegate("group_sql"))} -ELIMINATE_DISTINCT_ON = {exp.Select: preprocess([eliminate_distinct_on], delegate("select_sql"))} -ELIMINATE_QUALIFY = {exp.Select: preprocess([eliminate_qualify], delegate("select_sql"))} +UNALIAS_GROUP = {exp.Group: preprocess([unalias_group])} +ELIMINATE_DISTINCT_ON = {exp.Select: preprocess([eliminate_distinct_on])} +ELIMINATE_QUALIFY = {exp.Select: preprocess([eliminate_qualify])} REMOVE_PRECISION_PARAMETERIZED_TYPES = { - exp.Cast: preprocess([remove_precision_parameterized_types], delegate("cast_sql")) + exp.Cast: preprocess([remove_precision_parameterized_types]) } |